
import csv
import os

# Parameters consumed by calculate_cwe_bhm (BHM formula).
eta, theta = 1, 0.75

# Parameters consumed by calculate_cwe_bassier (Bassier formula).
eta_bassier, beta = 0.8, 6

# Cross-wage elasticity implied by the BHM parameters (eta, theta)
def calculate_cwe_bhm(s_leader, s_follower, eta, theta):
    """Return the cross-wage elasticity implied by the BHM parameters.

    Evaluates, with s_l = ``s_leader`` and s_f = ``s_follower``::

                      (1 + eta) * s_f * s_l * (eta - theta)
        --------------------------------------------------------------------
        (1 + (1 + eta) * (1 - s_f)) * s_f * (eta - theta) + theta * (eta + 1)

    In this script the shares come from the ``pol_wage_bill_share`` (leader)
    and ``np_wage_bill_share`` (follower) input columns.

    Args:
        s_leader: The leader's share (presumably in [0, 1]; not checked).
        s_follower: The follower's share (presumably in [0, 1]; not checked).
        eta: BHM parameter.
        theta: BHM parameter.

    Returns:
        The elasticity. It is 0 when ``eta == theta`` or ``s_leader == 0``.
        Raises ZeroDivisionError if the denominator is exactly zero.
    """
    spread = eta - theta
    one_plus_eta = 1 + eta
    # Keep the original left-to-right multiplication order so float results are identical.
    top = one_plus_eta * s_follower * s_leader * spread
    bottom = (1 + one_plus_eta * (1 - s_follower)) * s_follower * spread + theta * (eta + 1)
    return top / bottom

def calculate_cwe_bassier(eta_bassier, beta, s_leader, s_follower):
    """Return the cross-wage elasticity implied by the Bassier parameters.

    With k = 1 + eta_bassier * beta and g = beta * (1 - s_follower), the
    result is::

        k / (1 + eta_bassier * beta * s_leader)
            * (eta_bassier * beta * s_leader / k
               + beta**2 * s_leader * s_follower / (g * (1 + g) * k))

    Args:
        eta_bassier: Bassier model parameter.
        beta: Bassier model parameter.
        s_leader: The leader's share (presumably in [0, 1]; not checked).
        s_follower: The follower's share (presumably in [0, 1]; not checked).

    Returns:
        The elasticity. Raises ZeroDivisionError when ``s_follower == 1``
        (then g == 0) or ``beta == 0``.
    """
    scale = 1 + eta_bassier * beta
    gap = beta * (1 - s_follower)
    # Operand grouping mirrors the original expression so float results are identical.
    own_term = (eta_bassier * beta * s_leader) / scale
    cross_term = (beta * s_leader * beta * s_follower) / (gap * (1 + gap) * scale)
    return (scale / (1 + eta_bassier * beta * s_leader)) * (own_term + cross_term)

# Specify the input and output files.
# NOTE: home_dir is a placeholder and must be edited before running.
home_dir = "[Replace with your replication folder directory]/figures_tables/estimates"  # Directory for the output file
os.makedirs(home_dir, exist_ok=True)  # Create the directory if it doesn't exist
output_file = os.path.join(home_dir, "cross_wage_elasticities_theta_eta.csv")
input_file = os.path.join(home_dir, "cwe_estimates_pooled_for_plotting.csv")  # Name of the input file

# Input columns copied through to the output, in output order.
# Text columns are kept as strings; numeric columns are parsed with float().
text_columns = ["type", "exp_group"]
numeric_columns = [
    "pol_wage_bill_share",  # leader share passed to both CWE formulas
    "np_wage_bill_share",   # follower share passed to both CWE formulas
    "wb_share_ratio",
    "pol_emp_share",
    "np_emp_share",
    "estimate",
    "min95",
    "max95",
]
output_header = text_columns + numeric_columns + ["bhm_cwe", "bassier_cwe"]

# Read the input data, calculate elasticities, and save the results.
# newline="" is what the csv module requires for files it reads or writes.
with open(input_file, mode="r", newline="") as infile:
    reader = csv.reader(infile)

    # Read and validate the header BEFORE opening the output file, so a bad
    # input cannot truncate a previously written output.
    header = next(reader, None)
    if header is None:
        raise ValueError(f"Input CSV file is empty: {input_file}")
    missing = [name for name in text_columns + numeric_columns if name not in header]
    if missing:
        raise ValueError(f"Input CSV file is missing required columns: {', '.join(missing)}")

    # Column positions, resolved once from the header.
    text_positions = [header.index(name) for name in text_columns]
    numeric_positions = [header.index(name) for name in numeric_columns]

    with open(output_file, mode="w", newline="") as outfile:
        writer = csv.writer(outfile)
        writer.writerow(output_header)

        # Process each row in the input file
        for row in reader:
            try:
                text_values = [str(row[i]) for i in text_positions]
                numeric_values = [float(row[i]) for i in numeric_positions]
                s_leader, s_follower = numeric_values[0], numeric_values[1]
                cwe_bhm = calculate_cwe_bhm(s_leader, s_follower, eta, theta)
                cwe_bassier = calculate_cwe_bassier(eta_bassier, beta, s_leader, s_follower)
            # ValueError: non-numeric cell; IndexError: blank/short row;
            # ZeroDivisionError: a formula's denominator is zero (e.g. follower share == 1).
            except (ValueError, IndexError, ZeroDivisionError):
                print(f"Skipping row with invalid data: {row}")
                continue
            writer.writerow(text_values + numeric_values + [cwe_bhm, cwe_bassier])

print(f"Cross-wage elasticities have been written to {output_file}")
